{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# CIFAR10 Low Precision Training Example\n",
    "In this notebook, we present a quick example of how to simulate training a deep neural network in low precision with QPyTorch."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# import useful modules\n",
    "import argparse\n",
    "import os\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torchvision\n",
    "import torchvision.transforms as transforms\n",
    "from qtorch.quant import Quantizer, quantizer\n",
    "from qtorch.optim import OptimLP\n",
    "from torch.optim import SGD\n",
    "from qtorch import FloatingPoint\n",
    "from tqdm import tqdm\n",
    "import math"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We first load the data. In this example, we will experiment with CIFAR10."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n"
     ]
    }
   ],
   "source": [
    "# CIFAR10 datasets and data loaders; both splits share the same normalisation\n",
    "ds = torchvision.datasets.CIFAR10\n",
    "path = os.path.join(\"./data\", \"CIFAR10\")\n",
    "\n",
    "# per-channel mean / std handed to Normalize\n",
    "normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))\n",
    "\n",
    "# training images get a random 32x32 crop (after 4 px padding) and a random flip;\n",
    "# test images are only converted to tensors and normalised\n",
    "transform_train = transforms.Compose([\n",
    "    transforms.RandomCrop(32, padding=4),\n",
    "    transforms.RandomHorizontalFlip(),\n",
    "    transforms.ToTensor(),\n",
    "    normalize,\n",
    "])\n",
    "transform_test = transforms.Compose([transforms.ToTensor(), normalize])\n",
    "\n",
    "train_set = ds(path, train=True, download=True, transform=transform_train)\n",
    "test_set = ds(path, train=False, download=True, transform=transform_test)\n",
    "\n",
    "# settings common to both loaders; only the training loader shuffles\n",
    "loader_kwargs = dict(batch_size=128, num_workers=4, pin_memory=True)\n",
    "loaders = {\n",
    "    'train': torch.utils.data.DataLoader(train_set, shuffle=True, **loader_kwargs),\n",
    "    'test': torch.utils.data.DataLoader(test_set, **loader_kwargs),\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We then define the quantization setting we are going to use. In particular, here we follow the setting reported in the paper \"Training Deep Neural Networks with 8-bit Floating Point Numbers\", where the authors propose to use specialized 8-bit and 16-bit floating point formats."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# the two floating point formats: 8-bit (5 exponent / 2 mantissa bits)\n",
    "# and 16-bit (6 exponent / 9 mantissa bits)\n",
    "bit_8 = FloatingPoint(exp=5, man=2)\n",
    "bit_16 = FloatingPoint(exp=6, man=9)\n",
    "\n",
    "# quantization functions later handed to the optimizer wrapper:\n",
    "# weights and gradients use the 8-bit format with nearest rounding,\n",
    "# momentum and the accumulator use the 16-bit format with stochastic rounding\n",
    "weight_quant = quantizer(forward_number=bit_8, forward_rounding=\"nearest\")\n",
    "grad_quant = quantizer(forward_number=bit_8, forward_rounding=\"nearest\")\n",
    "momentum_quant = quantizer(forward_number=bit_16, forward_rounding=\"stochastic\")\n",
    "acc_quant = quantizer(forward_number=bit_16, forward_rounding=\"stochastic\")\n",
    "\n",
    "\n",
    "def act_error_quant():\n",
    "    \"\"\"Return a new Quantizer module using the 8-bit format and nearest rounding\n",
    "    for both its forward and backward settings.\n",
    "\n",
    "    Kept as a factory so that every call site gets its own module instance.\n",
    "    \"\"\"\n",
    "    return Quantizer(forward_number=bit_8, backward_number=bit_8,\n",
    "                     forward_rounding=\"nearest\", backward_rounding=\"nearest\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we define a low-precision ResNet. In the definition, we recursively insert a quantization module after every convolution layer. Note that the quantization of the weights, gradients, momentum, and gradient accumulator is not handled here."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def conv3x3(in_planes, out_planes, stride=1):\n",
    "    \"\"\"Build a bias-free 3x3 convolution with padding 1 and the given stride.\"\"\"\n",
    "    conv_kwargs = dict(kernel_size=3, stride=stride, padding=1, bias=False)\n",
    "    return nn.Conv2d(in_planes, out_planes, **conv_kwargs)\n",
    "\n",
    "class BasicBlock(nn.Module):\n",
    "    \"\"\"Residual block made of two BatchNorm -> ReLU -> quant -> 3x3 conv -> quant stages.\n",
    "\n",
    "    Args:\n",
    "        inplanes: number of input channels.\n",
    "        planes: number of channels produced by both convolutions.\n",
    "        quant: zero-argument callable; called once to create the quantization\n",
    "            module that is reused at every quantization point of the block.\n",
    "        stride: stride of the first convolution.\n",
    "        downsample: optional module applied to the block input to form the\n",
    "            shortcut; the input itself is used when this is None.\n",
    "    \"\"\"\n",
    "    expansion = 1\n",
    "\n",
    "    def __init__(self, inplanes, planes, quant, stride=1, downsample=None):\n",
    "        super().__init__()\n",
    "        # sub-modules are registered in this order on purpose: it determines\n",
    "        # the order of the block's parameters\n",
    "        self.bn1 = nn.BatchNorm2d(inplanes)\n",
    "        self.relu = nn.ReLU(inplace=True)\n",
    "        self.conv1 = conv3x3(inplanes, planes, stride)\n",
    "        self.bn2 = nn.BatchNorm2d(planes)\n",
    "        self.conv2 = conv3x3(planes, planes)\n",
    "        self.downsample = downsample\n",
    "        self.stride = stride\n",
    "        self.quant = quant()\n",
    "\n",
    "    def forward(self, x):\n",
    "        out = x\n",
    "        # both stages have the same shape: normalise, activate, quantize,\n",
    "        # convolve, quantize again\n",
    "        for norm, conv in ((self.bn1, self.conv1), (self.bn2, self.conv2)):\n",
    "            out = self.quant(self.relu(norm(out)))\n",
    "            out = self.quant(conv(out))\n",
    "\n",
    "        # shortcut: the raw input, or its projection when a downsample module is set\n",
    "        shortcut = x if self.downsample is None else self.downsample(x)\n",
    "        out += shortcut\n",
    "\n",
    "        return out\n",
    "    \n",
    "class PreResNet(nn.Module):\n",
    "    \"\"\"ResNet built from three stages of BasicBlock with quantization modules\n",
    "    inserted around the first convolution, the final activation and the classifier.\n",
    "\n",
    "    Args:\n",
    "        quant: zero-argument callable returning a quantization module; it is\n",
    "            forwarded to every block and called once more for the network itself.\n",
    "        num_classes: output size of the final linear layer.\n",
    "        depth: network depth; must be of the form 6n+2, giving n blocks per stage.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, quant, num_classes=10, depth=20):\n",
    "        super().__init__()\n",
    "        assert (depth - 2) % 6 == 0, 'depth should be 6n+2'\n",
    "        blocks_per_stage = (depth - 2) // 6\n",
    "\n",
    "        block = BasicBlock\n",
    "        final_planes = 64 * block.expansion\n",
    "\n",
    "        # running channel count, updated by _make_layer as stages are added\n",
    "        self.inplanes = 16\n",
    "        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1, bias=False)\n",
    "        self.layer1 = self._make_layer(block, 16, blocks_per_stage, quant)\n",
    "        self.layer2 = self._make_layer(block, 32, blocks_per_stage, quant, stride=2)\n",
    "        self.layer3 = self._make_layer(block, 64, blocks_per_stage, quant, stride=2)\n",
    "        self.bn = nn.BatchNorm2d(final_planes)\n",
    "        self.relu = nn.ReLU(inplace=True)\n",
    "        self.avgpool = nn.AvgPool2d(8)\n",
    "        self.fc = nn.Linear(final_planes, num_classes)\n",
    "        self.quant = quant()\n",
    "\n",
    "        # 16-bit format (6 exponent / 9 mantissa bits) used on the network\n",
    "        # input and on the classifier output\n",
    "        half_format = FloatingPoint(exp=6, man=9)\n",
    "        self.quant_half = Quantizer(half_format, half_format, \"nearest\", \"nearest\")\n",
    "\n",
    "        # weight initialisation: normal with std sqrt(2 / (kh * kw * out_channels))\n",
    "        # for convolutions, weight 1 and bias 0 for batch norm\n",
    "        for m in self.modules():\n",
    "            if isinstance(m, nn.Conv2d):\n",
    "                kernel_h, kernel_w = m.kernel_size\n",
    "                fan_out = kernel_h * kernel_w * m.out_channels\n",
    "                m.weight.data.normal_(0, math.sqrt(2. / fan_out))\n",
    "            elif isinstance(m, nn.BatchNorm2d):\n",
    "                m.weight.data.fill_(1)\n",
    "                m.bias.data.zero_()\n",
    "\n",
    "    def _make_layer(self, block, planes, blocks, quant, stride=1):\n",
    "        \"\"\"Stack `blocks` blocks into one stage; only the first block receives\n",
    "        `stride` and, when the shape changes, a 1x1 convolution shortcut.\"\"\"\n",
    "        out_planes = planes * block.expansion\n",
    "\n",
    "        downsample = None\n",
    "        if stride != 1 or self.inplanes != out_planes:\n",
    "            projection = nn.Conv2d(self.inplanes, out_planes,\n",
    "                                   kernel_size=1, stride=stride, bias=False)\n",
    "            downsample = nn.Sequential(projection)\n",
    "\n",
    "        layers = [block(self.inplanes, planes, quant, stride, downsample)]\n",
    "        self.inplanes = out_planes\n",
    "        layers.extend(block(self.inplanes, planes, quant) for _ in range(1, blocks))\n",
    "\n",
    "        return nn.Sequential(*layers)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.quant_half(x)\n",
    "        x = self.quant(self.conv1(x))\n",
    "\n",
    "        # spatial size per stage as noted by the original author: 32x32, 16x16, 8x8\n",
    "        for stage in (self.layer1, self.layer2, self.layer3):\n",
    "            x = stage(x)\n",
    "        x = self.quant(self.relu(self.bn(x)))\n",
    "\n",
    "        x = self.avgpool(x)\n",
    "        x = x.view(x.size(0), -1)  # flatten everything but the batch dimension\n",
    "        x = self.fc(x)\n",
    "\n",
    "        return self.quant_half(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# build the network with its defaults (depth=20, num_classes=10); act_error_quant is\n",
    "# the factory the network and each block call to create their quantization modules\n",
    "model = PreResNet(act_error_quant)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# use the GPU when one is available and fall back to the CPU otherwise, so the\n",
    "# notebook also runs on machines without CUDA; set device = 'cpu' to force the CPU\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "model = model.to(device=device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We now use the low-precision optimizer wrapper to help define the quantization of weight, gradient, momentum, and gradient accumulator."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ordinary SGD with momentum is the base optimizer; OptimLP wraps it together\n",
    "# with the quantization functions defined earlier\n",
    "sgd = SGD(model.parameters(), lr=0.05, momentum=0.9, weight_decay=5e-4)\n",
    "optimizer = OptimLP(\n",
    "    sgd,\n",
    "    weight_quant=weight_quant,\n",
    "    grad_quant=grad_quant,\n",
    "    momentum_quant=momentum_quant,\n",
    "    acc_quant=acc_quant,\n",
    "    # run_epoch multiplies the loss by 1000 before backward; this factor\n",
    "    # presumably undoes that scaling -- confirm against the OptimLP docs\n",
    "    grad_scaling=1/1000,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can reuse common training scripts without any extra code to handle quantization."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_epoch(loader, model, criterion, optimizer=None, phase=\"train\"):\n",
    "    \"\"\"Run one pass over `loader` and report the mean loss and the accuracy.\n",
    "\n",
    "    Batches are moved to the notebook-level `device` before the forward pass.\n",
    "\n",
    "    Args:\n",
    "        loader: iterable of (input, target) batches that supports len().\n",
    "        model: network to run; put in train or eval mode according to `phase`.\n",
    "        criterion: callable mapping (output, target) to a scalar loss tensor.\n",
    "        optimizer: optimizer stepped after every batch; required when `phase`\n",
    "            is \"train\" and unused when it is \"eval\".\n",
    "        phase: \"train\" (gradients enabled, parameters updated) or \"eval\"\n",
    "            (gradients disabled, no update).\n",
    "\n",
    "    Returns:\n",
    "        dict with 'loss' (loss averaged over all samples) and 'accuracy'\n",
    "        (percentage of samples whose arg-max output equals the target).\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `phase` is \"train\" and no optimizer is given, or if the\n",
    "            loader yields no samples.\n",
    "    \"\"\"\n",
    "    assert phase in [\"train\", \"eval\"], \"invalid running phase\"\n",
    "    training = phase == \"train\"\n",
    "    if training and optimizer is None:\n",
    "        # fail before any computation instead of after the first forward pass\n",
    "        raise ValueError('an optimizer is required when phase is \"train\"')\n",
    "\n",
    "    loss_sum = 0.0\n",
    "    correct = 0  # plain int, so the final division also works for short loaders\n",
    "    ttl = 0\n",
    "\n",
    "    if training:\n",
    "        model.train()\n",
    "    else:\n",
    "        model.eval()\n",
    "\n",
    "    with torch.autograd.set_grad_enabled(training):\n",
    "        for inputs, target in tqdm(loader, total=len(loader)):\n",
    "            inputs = inputs.to(device=device)\n",
    "            target = target.to(device=device)\n",
    "            output = model(inputs)\n",
    "            loss = criterion(output, target)\n",
    "\n",
    "            batch_size = inputs.size(0)\n",
    "            # weight by batch size so the final mean is per sample\n",
    "            loss_sum += loss.item() * batch_size\n",
    "            pred = output.argmax(dim=1)\n",
    "            correct += pred.eq(target.view_as(pred)).sum().item()\n",
    "            ttl += batch_size\n",
    "\n",
    "            if training:\n",
    "                # back-propagate 1000x the loss; the optimizer in this notebook is\n",
    "                # built with grad_scaling=1/1000, which presumably undoes the\n",
    "                # factor -- keep the two values in sync\n",
    "                loss = loss * 1000\n",
    "                optimizer.zero_grad()\n",
    "                loss.backward()\n",
    "                optimizer.step()\n",
    "\n",
    "    if ttl == 0:\n",
    "        raise ValueError(\"loader yielded no samples\")\n",
    "\n",
    "    return {\n",
    "        'loss': loss_sum / float(ttl),\n",
    "        'accuracy': correct / float(ttl) * 100.0,\n",
    "    }"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Begin the training process just as usual. Enjoy!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 391/391 [00:14<00:00, 26.41it/s]\n",
      "100%|██████████| 79/79 [00:01<00:00, 78.18it/s]\n"
     ]
    }
   ],
   "source": [
    "# arguments shared by the training and the evaluation pass\n",
    "shared_args = dict(model=model, criterion=F.cross_entropy, optimizer=optimizer)\n",
    "\n",
    "# one epoch: a training pass followed by an evaluation pass on the test split\n",
    "for epoch in range(1):\n",
    "    train_res = run_epoch(loaders['train'], phase=\"train\", **shared_args)\n",
    "    test_res = run_epoch(loaders['test'], phase=\"eval\", **shared_args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'loss': 1.6471979439544677, 'accuracy': 37.566}"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# mean loss and accuracy (in percent) of the last training pass\n",
    "train_res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'loss': 1.5749474658966065, 'accuracy': 43.63}"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# mean loss and accuracy (in percent) of the last evaluation pass on the test split\n",
    "test_res"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
